# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Triton utilities for JAX primitives.

This module provides utility functions for integrating Triton kernels into
JAX primitives. Triton is only imported when this module is used.
"""

import hashlib
from typing import Any, Callable, Mapping
import zlib

from jax import core
import jax
import jax.numpy as jnp


try:
    from jax._src.lib import gpu_triton
    from triton.compiler import compiler as tc
    from triton.backends.nvidia import compiler as cb
    from triton.runtime import autotuner
except ImportError as e:
    raise ImportError(
        "Triton is required for transformer_engine.jax.triton_extensions. "
        "Install with: pip install triton\n"
        "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
    ) from e


# Only the lowering helper is exported via `from ... import *`; the other
# functions in this file remain importable by name.
__all__ = ["triton_call_lowering"]

# Triton kernel cache (module-level, shared across all kernels).
# compile_triton() stores each compiled kernel here under a string cache key
# and returns the stored object on later calls with the same key. Nothing in
# this file evicts entries.
_TRITON_KERNEL_CACHE = {}


# JAX dtype -> Triton element type name. Built once at import time rather than
# on every call.
_TRITON_DTYPE_NAMES = {
    jnp.dtype("bfloat16"): "bf16",
    jnp.dtype("float64"): "fp64",
    jnp.dtype("float32"): "fp32",
    jnp.dtype("float16"): "fp16",
    jnp.dtype("float8_e4m3fn"): "fp8e4nv",
    jnp.dtype("float8_e5m2"): "fp8e5",
    jnp.dtype("int64"): "i64",
    jnp.dtype("int32"): "i32",
    jnp.dtype("int16"): "i16",
    jnp.dtype("int8"): "i8",
    jnp.dtype("uint64"): "u64",
    jnp.dtype("uint32"): "u32",
    jnp.dtype("uint16"): "u16",
    jnp.dtype("uint8"): "u8",
    jnp.dtype("bool"): "i1",
}


def get_triton_dtype(aval):
    """Return the Triton pointer type string for a JAX abstract array.

    The result is always the element type prefixed with ``*`` (a pointer
    type such as ``"*fp32"``); this function never produces a scalar type
    string.

    Args:
        aval: JAX ShapedArray describing the array argument.

    Returns:
        Triton pointer type string (e.g., "*fp32", "*bf16", "*i1").

    Raises:
        TypeError: If ``aval`` is not a ``core.ShapedArray``.
        ValueError: If the dtype of ``aval`` has no entry in the dtype table.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not isinstance(aval, core.ShapedArray):
        raise TypeError(f"aval must be a JAX ShapedArray, got {type(aval).__name__}")

    try:
        element_type = _TRITON_DTYPE_NAMES[aval.dtype]
    except KeyError:
        supported = ", ".join(str(dtype) for dtype in _TRITON_DTYPE_NAMES)
        raise ValueError(
            f"Unsupported dtype {aval.dtype} for Triton kernels; supported dtypes: {supported}"
        ) from None

    return f"*{element_type}"


def compile_triton(
    kernel_fn: Callable,
    signature: Mapping[str, str],
    constants: Mapping[str, Any],
    num_warps: int,
    num_stages: int,
    num_ctas: int,
    compute_capability: int,
    enable_fp_fusion: bool = False,
):
    """Compile a Triton kernel to PTX and wrap it for JAX.

    Compiled kernels are stored in the module-level ``_TRITON_KERNEL_CACHE``
    and reused when the same kernel is requested again with identical
    signature, constants and launch options.

    Args:
        kernel_fn: Triton kernel function (decorated with @triton.jit)
        signature: Dict mapping arg names to types (e.g., {"x_ptr": "*fp32", "n": "i32"})
        constants: Dict of compile-time constants
        num_warps: Number of warps per block
        num_stages: Number of pipeline stages
        num_ctas: Number of CTAs (cooperative thread arrays)
        compute_capability: CUDA compute capability
        enable_fp_fusion: Enable FP fusion optimizations (default False for accuracy)

    Returns:
        TritonKernel object for JAX
    """
    # Identify the kernel by module and qualified name, not the bare name:
    # two kernels that merely share a function name (e.g. "_kernel" in two
    # different modules) must not be served each other's compiled binary.
    kernel_identity = (
        getattr(kernel_fn, "__module__", None),
        getattr(kernel_fn, "__qualname__", kernel_fn.__name__),
    )
    key_fields = (
        kernel_identity,
        tuple(sorted(signature.items())),
        tuple(sorted(constants.items())),
        num_warps,
        num_stages,
        num_ctas,
        enable_fp_fusion,
        compute_capability,
    )
    # The digest is only a dictionary key (no security role). SHA-256 is used
    # because MD5 can be refused by FIPS-restricted Python builds.
    cache_key = hashlib.sha256(repr(key_fields).encode()).hexdigest()

    cached_kernel = _TRITON_KERNEL_CACHE.get(cache_key)
    if cached_kernel is not None:
        return cached_kernel

    # Compile kernel
    options = cb.CUDAOptions(
        num_warps=num_warps,
        num_stages=num_stages,
        num_ctas=num_ctas,
        cluster_dims=(1, 1, 1),
        debug=False,
        enable_fp_fusion=enable_fp_fusion,
    )

    # Mark constants as constexpr in the signature. Only names already present
    # in `signature` are rewritten; constants that are not in the signature are
    # passed through `constexprs` alone.
    signature_with_constexpr = {
        arg_name: ("constexpr" if arg_name in constants else arg_type)
        for arg_name, arg_type in signature.items()
    }

    src = tc.ASTSource(
        fn=kernel_fn,
        constexprs=constants,
        signature=signature_with_constexpr,
    )

    compiled = tc.compile(
        src,
        # NOTE(review): the trailing 32 is presumably the warp size — confirm
        # against the GPUTarget definition of the installed Triton version.
        target=tc.GPUTarget("cuda", compute_capability, 32),
        options=options.__dict__,
    )

    # Create kernel object for JAX
    kernel = gpu_triton.TritonKernel(
        compiled.name,
        num_warps,
        compiled.metadata.shared,
        compiled.asm["ptx"],
        "",  # ttir
        compute_capability,
        1,
        1,
        1,  # cluster_dims
    )

    _TRITON_KERNEL_CACHE[cache_key] = kernel
    return kernel


def triton_call_lowering(
    ctx,
    kernel_fn: Callable,
    *array_args,
    grid,
    input_output_aliases: Mapping[int, int] = None,
    constexprs: Mapping[str, Any] = None,
):
    """Helper for MLIR lowering that calls a Triton kernel.

    Use this in your primitive's lowering method to call Triton kernels.

    The kernel signature is built by pairing the kernel's leading parameter
    names, in order, with ``ctx.avals_in`` followed by ``ctx.avals_out``; every
    such argument is passed as a pointer. All other kernel parameters must be
    supplied through ``constexprs``.

    If ``kernel_fn`` is a Triton ``Autotuner``, every config is compiled and
    the resulting calls are bundled into one autotuned kernel call. User
    ``constexprs`` take precedence over a config's kwargs on name clashes.

    Args:
        ctx: MLIR lowering context
        kernel_fn: Triton kernel function
        *array_args: Input arrays (from ctx)
        grid: Grid dimensions (int or tuple). Fewer than three dimensions are
            padded with 1; dimensions beyond the third are ignored.
        input_output_aliases: Mapping of input to output aliases
        constexprs: Compile-time constants for the kernel

    Returns:
        MLIR lowering result

    Raises:
        ValueError: If ``grid`` is empty, or if there are more input/output
            arrays than the kernel has parameters.

    Example:
        @staticmethod
        def lowering(ctx, x, *, block_size):
            from ..triton_extensions import triton_call_lowering
            n = ctx.avals_in[0].size
            return triton_call_lowering(
                ctx, my_kernel, x,
                grid=(triton.cdiv(n, block_size),),
                constexprs={"n_elements": n, "BLOCK_SIZE": block_size},
            )
    """
    # Get compute capability using gpu_triton
    compute_capability = gpu_triton.get_compute_capability(0)  # device 0

    # Default launch parameters, used for plain kernels and for autotuner
    # configs that leave a value unset.
    # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
    default_num_warps = 32
    default_num_stages = 1
    default_num_ctas = 1

    user_constexprs = constexprs if constexprs is not None else {}
    aliases = input_output_aliases if input_output_aliases is not None else {}

    # An Autotuner wraps the real jitted kernel in `.fn`.
    is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
    actual_kernel_fn = kernel_fn.fn if is_autotuned else kernel_fn
    arg_names = actual_kernel_fn.arg_names

    # Build signature for inputs + outputs
    all_avals = list(ctx.avals_in) + list(ctx.avals_out)
    if len(all_avals) > len(arg_names):
        raise ValueError(
            f"Kernel {actual_kernel_fn.__name__} declares {len(arg_names)} parameter(s) but "
            f"{len(all_avals)} input/output array(s) were provided"
        )
    signature = {name: get_triton_dtype(aval) for name, aval in zip(arg_names, all_avals)}

    # Normalize grid to 3D
    grid_dims = (grid,) if isinstance(grid, int) else tuple(grid)
    if not grid_dims:
        raise ValueError("grid must have at least one dimension")
    grid_x, grid_y, grid_z = (grid_dims + (1, 1))[:3]

    def _make_kernel_call(compiled_kernel):
        """Wrap a compiled kernel with the launch grid and one array parameter per aval."""
        # NOTE(review): the (0, 16) arguments are presumably
        # (bytes_to_zero, pointer_divisibility) — confirm against gpu_triton.
        params = [gpu_triton.create_array_parameter(0, 16) for _ in all_avals]
        return gpu_triton.TritonKernelCall(compiled_kernel, grid_x, grid_y, grid_z, params)

    if is_autotuned:
        # Compile all configs for runtime selection
        kernel_calls = []
        for config in kernel_fn.configs:
            config_kernel = compile_triton(
                actual_kernel_fn,
                signature,
                # Merge config kwargs with user constexprs (user values win).
                {**config.kwargs, **user_constexprs},
                config.num_warps if config.num_warps is not None else default_num_warps,
                config.num_stages if config.num_stages is not None else default_num_stages,
                config.num_ctas if config.num_ctas is not None else default_num_ctas,
                compute_capability,
                enable_fp_fusion=False,
            )
            kernel_calls.append((_make_kernel_call(config_kernel), str(config)))

        # The autotuned call takes each alias together with the aliased
        # input's size in bytes.
        aliases_with_sizes = tuple(
            (
                input_idx,
                output_idx,
                ctx.avals_in[input_idx].size * ctx.avals_in[input_idx].dtype.itemsize,
            )
            for input_idx, output_idx in aliases.items()
        )

        kernel_call = gpu_triton.TritonAutotunedKernelCall(
            f"{actual_kernel_fn.__name__}_autotuned",
            kernel_calls,
            aliases_with_sizes,
        )
    else:
        # Regular kernel: compile single config
        kernel = compile_triton(
            actual_kernel_fn,
            signature,
            user_constexprs,
            default_num_warps,
            default_num_stages,
            default_num_ctas,
            compute_capability,
            enable_fp_fusion=False,
        )
        kernel_call = _make_kernel_call(kernel)

    serialized_metadata = b""
    call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)

    # Use JAX FFI lowering with compressed protobuf
    rule = jax.ffi.ffi_lowering(
        "triton_kernel_call",  # Custom call target registered in gpu_triton.py
        api_version=2,
        backend_config=zlib.compress(call_proto),
        operand_output_aliases=aliases,
    )

    return rule(ctx, *array_args)
